{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmark \n",
    "---\n",
    "\n",
    "This notebook compares the decoding performance of the Viterbi decoder and the Neural Decoder [1] on convolutional codes over an AWGN channel.\n",
    "\n",
    "Reference:\n",
    "* [1] Kim, Hyeji, et al. \"Communication Algorithms via Deep Learning.\" ICLR (2018)\n",
    "\n",
    "\n",
    "### TODOs:\n",
    "---\n",
    "\n",
    "* [x] Benchmark Viterbi Decoder \n",
    "* [x] Benchmark Neural Decoder\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the parent directory on the import path so that modules living one\n",
    "# level above this notebook can be imported. Approach taken from:\n",
    "# https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im\n",
    "import os\n",
    "import sys\n",
    "\n",
    "module_path = os.path.abspath(os.pardir)\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import pickle\n",
    "import numpy as np\n",
    "import commpy as cp\n",
    "import tensorflow as tf\n",
    "import multiprocessing as mp\n",
    "\n",
    "from deepcom.metrics import BER, BLER         # metrics to benchmark Neural Decoder Model\n",
    "from deepcom.utils import corrupt_signal      # simulate a AWGN Channel\n",
    "from deepcom.dataset import data_genenerator  # data loader for Tensorflow\n",
    "from deepcom.NeuralDecoder import NRSCDecoder # Neural Decoder Model\n",
    "\n",
    "# Benchmark configuration shared by the cells below.\n",
    "BATCH_SIZE = 500       # depends on size of GPU, should be a factor of number_testing_sequences\n",
    "BLOCK_LEN = 100        # number of message bits per block\n",
    "CONSTRAINT_LEN = 3     # num of shifts in Conv. Encoder\n",
    "TRACE_BACK_DEPTH = 15  # passed as tb_depth (traceback depth) to the Viterbi decoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Dataset\n",
    "* The dataset should be generated with the script `generate_synthetic_dataset.py`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_PATH = '../dataset/rnn_12k_bl100_snr0.dataset'\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code embedded in the file;\n",
    "# only open datasets that come from a trusted source.\n",
    "with open(DATASET_PATH, 'rb') as f:\n",
    "    # The pickle holds four items; the first two (training data) are\n",
    "    # discarded and only the test split is kept.\n",
    "    _, _, X_test, Y_test = pickle.load(f)\n",
    "\n",
    "print(f'Number of testing sequences {len(X_test)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load pre-trained Neural Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_log = '../logs/BiGRU-2-400::dropout-0.7::epochs7::120k'\n",
    "model_path = os.path.join(experiment_log, 'BiGRU.hdf5')\n",
    "print(model_path)\n",
    "\n",
    "try:\n",
    "    # BER and BLER are handed over as custom_objects so Keras can resolve\n",
    "    # them while deserializing the saved model.\n",
    "    model = tf.keras.models.load_model(model_path, custom_objects={'BER': BER, 'BLER': BLER})\n",
    "    # Re-compile with a plain loss; in this notebook the model is only used\n",
    "    # through predict().\n",
    "    model.compile(optimizer='adam', loss='binary_crossentropy')\n",
    "except Exception as e:\n",
    "    print(e)\n",
    "    # Loading can fail for reasons other than a missing file (corrupt file,\n",
    "    # incompatible versions, ...), so report the path and chain the cause\n",
    "    # instead of claiming the model was not found.\n",
    "    raise ValueError('Failed to load pre-trained model from {}'.format(model_path)) from e\n",
    "\n",
    "print('Pre-trained model is loaded.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Viterbi/Neural Decoder Benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark_neural_decoder(noisy_inputs, labels):\n",
    "    \"\"\"Decode noisy codewords with the pre-trained neural decoder.\n",
    "\n",
    "    Args:\n",
    "        noisy_inputs: sequence of received codewords, one per message. Only\n",
    "            the first 2 * BLOCK_LEN symbols of each codeword are used.\n",
    "        labels: sequence of original messages, BLOCK_LEN bits each.\n",
    "\n",
    "    Returns:\n",
    "        1-D array holding the number of bit errors (Hamming distance)\n",
    "        for each message.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the number of messages is not a multiple of BATCH_SIZE.\n",
    "    \"\"\"\n",
    "    # Set up data generator\n",
    "    Y = np.reshape(labels, (-1, BLOCK_LEN, 1))\n",
    "    X = np.reshape(np.array(noisy_inputs)[:, :2*BLOCK_LEN], (-1, BLOCK_LEN, 2))\n",
    "\n",
    "    # predict() is asked for len(Y) // BATCH_SIZE batches, so a remainder\n",
    "    # would be dropped silently and the predictions would no longer line up\n",
    "    # with the labels. Fail early with a clear message instead.\n",
    "    num_batches, remainder = divmod(len(Y), BATCH_SIZE)\n",
    "    if remainder:\n",
    "        raise ValueError(\n",
    "            'Number of sequences ({}) must be a multiple of BATCH_SIZE ({}).'.format(\n",
    "                len(Y), BATCH_SIZE))\n",
    "\n",
    "    test_set = data_genenerator(X, Y, BATCH_SIZE, shuffle=False)\n",
    "\n",
    "    # Make predictions on batch\n",
    "    # NOTE(review): make_one_shot_iterator() looks like the TF 1.x tf.data\n",
    "    # API -- confirm against the installed TensorFlow version.\n",
    "    decoded_bits = model.predict(\n",
    "        test_set.make_one_shot_iterator(),\n",
    "        steps=num_batches)\n",
    "\n",
    "    # Compute hamming distances\n",
    "    original_bits = np.reshape(Y, (-1, BLOCK_LEN)).astype(int)\n",
    "    # Round the network outputs to hard 0/1 decisions before comparing.\n",
    "    decoded_bits = np.reshape(np.round(decoded_bits), (-1, BLOCK_LEN)).astype(int)\n",
    "    hamming_dist = np.not_equal(original_bits, decoded_bits)\n",
    "\n",
    "    return np.sum(hamming_dist, axis=1)\n",
    "\n",
    "def benchmark_viterbi(message_bits, noisy_bits, sigma):\n",
    "    \"\"\"Decode one noisy codeword with the Viterbi decoder and count bit errors.\n",
    "\n",
    "    Args:\n",
    "        message_bits: original message bits (numpy array).\n",
    "        noisy_bits: received codeword (numpy array); it is left unmodified.\n",
    "        sigma: not used by this function; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        Number of bit errors (Hamming distance) between the original message\n",
    "        and the decoded bits.\n",
    "    \"\"\"\n",
    "    # M is a one-element array; index it explicitly because converting a\n",
    "    # non-0-d array with int() is deprecated in recent NumPy releases.\n",
    "    memory = int(M[0])\n",
    "    # astype() returns a copy, so the caller's array is not modified below.\n",
    "    received = noisy_bits.astype(float)\n",
    "    # make fair comparison between (100, 204) convolutional code and RNN decoder\n",
    "    # Reference: Author's code\n",
    "    received[-2 * memory:] = 0\n",
    "    # Viterbi Decoder on Conv. Code\n",
    "    decoded_bits = cp.channelcoding.viterbi_decode(\n",
    "        coded_bits=received,\n",
    "        trellis=trellis,\n",
    "        tb_depth=TRACE_BACK_DEPTH,\n",
    "        decoding_type='unquantized')\n",
    "    # Number of bit errors (hamming distance); the last `memory` decoded bits\n",
    "    # are dropped before comparing with the message.\n",
    "    hamming_dist = cp.utilities.hamming_dist(\n",
    "        message_bits.astype(int),\n",
    "        decoded_bits[:-memory])\n",
    "    return hamming_dist\n",
    "\n",
    "\n",
    "from commpy.channelcoding import Trellis\n",
    "\n",
    "# Trellis used by both the encoder call (passed to generate_noisy_input)\n",
    "# and the Viterbi benchmark (read as a global in benchmark_viterbi).\n",
    "# Generator matrix (octal representation)\n",
    "G = np.array([[0o7, 0o5]]) \n",
    "M = np.array([CONSTRAINT_LEN - 1])  # Number of delay elements in the convolutional encoder\n",
    "# code_type='rsc' with feedback polynomial 0o7 (octal); presumably a\n",
    "# recursive systematic code -- see the commpy Trellis documentation.\n",
    "trellis = Trellis(M, G, feedback=0o7, code_type='rsc')\n",
    "\n",
    "# #################################################################\n",
    "# New noisy signals are generated for every SNR (dB) value,\n",
    "# for a fair comparison.\n",
    "# #################################################################\n",
    "def generate_noisy_input(message_bits, trellis, sigma):\n",
    "    \"\"\"Encode a message and corrupt the codeword with AWGN noise.\n",
    "\n",
    "    Args:\n",
    "        message_bits: message bits to encode.\n",
    "        trellis: trellis handed to the convolutional encoder.\n",
    "        sigma: noise parameter forwarded to corrupt_signal.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (noisy codeword, message_bits); the message is returned\n",
    "        unchanged alongside its noisy codeword.\n",
    "    \"\"\"\n",
    "    encoded = cp.channelcoding.conv_encode(message_bits, trellis)\n",
    "    received = corrupt_signal(encoded, noise_type='awgn', sigma=sigma)\n",
    "    return received, message_bits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "viterbiBERs, viterbiBLERs = [], []\n",
    "neuralBERs, neuralBLERs = [], []\n",
    "\n",
    "SNRs = np.linspace(0, 7.0, 8)\n",
    "labels = np.reshape(Y_test, (-1, BLOCK_LEN)).astype(int)\n",
    "\n",
    "pool = mp.Pool(processes=mp.cpu_count())\n",
    "try:\n",
    "    for snr in SNRs:\n",
    "        # Shift the SNR by 10*log10(1/2) dB (presumably the rate-1/2\n",
    "        # adjustment). The value is still in dB; the conversion to linear\n",
    "        # scale happens inside the sigma formula below.\n",
    "        snr_adjusted_db = snr + 10 * np.log10(1./2.)\n",
    "        sigma = np.sqrt(1. / (2. * 10 ** (snr_adjusted_db / 10.)))\n",
    "        print('[SNR]={:.2f}'.format(snr))\n",
    "\n",
    "        # Generates new noisy signals\n",
    "        result = pool.starmap(generate_noisy_input, [(msg_bits, trellis, sigma) for msg_bits in labels])\n",
    "        X, Y = zip(*result)\n",
    "        # Both error rates are normalized by the number of MESSAGE bits.\n",
    "        # X holds the (longer) coded sequences, so its size must not be used.\n",
    "        num_message_bits = np.prod(np.shape(Y))\n",
    "        num_blocks = len(Y)\n",
    "\n",
    "        # #################################################################\n",
    "        # BENCHMARK NEURAL DECODER\n",
    "        # #################################################################\n",
    "        hamm_dists = benchmark_neural_decoder(X, Y)\n",
    "        nn_ber = np.sum(hamm_dists) / num_message_bits\n",
    "        nn_bler = np.count_nonzero(hamm_dists) / num_blocks\n",
    "\n",
    "        neuralBERs.append(nn_ber)\n",
    "        neuralBLERs.append(nn_bler)\n",
    "        print('\\tNeural Decoder:  [BER]={:5.7f} [BLER]={:5.3f} '.format(nn_ber, nn_bler))\n",
    "\n",
    "        # #################################################################\n",
    "        # BENCHMARK VITERBI DECODER\n",
    "        # #################################################################\n",
    "        hamm_dists = pool.starmap(benchmark_viterbi, [(y, x, sigma) for x, y in zip(X, Y)])\n",
    "        # Bit error rate\n",
    "        ber = np.sum(hamm_dists) / num_message_bits\n",
    "        # Block error rate\n",
    "        bler = np.count_nonzero(hamm_dists) / num_blocks\n",
    "        viterbiBERs.append(ber)\n",
    "        viterbiBLERs.append(bler)\n",
    "        print('\\tViterbi Decoder: [BER]={:5.7f} [BLER]={:5.3f} '.format(ber, bler))\n",
    "finally:\n",
    "    # Errors are deliberately not swallowed: a failed run would otherwise\n",
    "    # leave partial result lists and only surface later as a confusing\n",
    "    # plotting error. The worker pool is always released.\n",
    "    pool.close()\n",
    "    pool.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# BER vs. SNR for both decoders (log-scale y axis).\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "ax.semilogy(SNRs, neuralBERs, '-vr')\n",
    "ax.semilogy(SNRs, viterbiBERs, '-s')\n",
    "# Build the title from the actual run configuration; the previous hard-coded\n",
    "# text (N = 10,000, block length 1000) contradicted BLOCK_LEN.\n",
    "# N is taken to be the number of benchmarked sequences.\n",
    "ax.set_title(\n",
    "    'N = {:,} || Block Length = {} || Data rate = 1/2'.format(len(labels), BLOCK_LEN),\n",
    "    fontsize=14)\n",
    "ax.legend(['N-RSC (SNR_train=0 dB)', 'Viterbi'], fontsize=16)\n",
    "ax.set_xlabel('SNR', fontsize=16)\n",
    "ax.set_ylabel('BER', fontsize=16)\n",
    "ax.set_xlim(SNRs[0], SNRs[-1])\n",
    "ax.grid(True, which='both')\n",
    "# NOTE(review): the file name still mentions block length 1000; it is kept\n",
    "# unchanged so existing references to the image keep working.\n",
    "fig.savefig('result_ber_block_length_1000_snr0.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BLER vs. SNR for both decoders (log-scale y axis).\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "ax.semilogy(SNRs, neuralBLERs, '-vr')\n",
    "ax.semilogy(SNRs, viterbiBLERs, '-s')\n",
    "ax.set_ylabel('BLER', fontsize=16)\n",
    "ax.set_xlabel('SNR', fontsize=16)\n",
    "# Build the title from the actual run configuration; the previous hard-coded\n",
    "# text (N = 10,000, block length 1000) contradicted BLOCK_LEN.\n",
    "# N is taken to be the number of benchmarked sequences.\n",
    "ax.set_title(\n",
    "    'N = {:,} || Block Length = {} || Data rate = 1/2'.format(len(labels), BLOCK_LEN),\n",
    "    fontsize=14)\n",
    "ax.legend(['N-RSC (SNR_train=0 dB)', 'Viterbi'], fontsize=16)\n",
    "\n",
    "ax.set_xlim(SNRs[0], SNRs[-1])\n",
    "ax.grid(True, which='both')\n",
    "# NOTE(review): the file name still mentions block length 1000; it is kept\n",
    "# unchanged so existing references to the image keep working.\n",
    "fig.savefig('result_bler_block_length_1000_snr0.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
